1mod cache;
60
61#[cfg(not(feature = "boxed-trait"))]
62use std::future::Future;
63use std::{
64 any::{Any, TypeId},
65 borrow::Cow,
66 collections::{HashMap, HashSet},
67 hash::Hash,
68 sync::{
69 atomic::{AtomicBool, Ordering},
70 Arc, Mutex,
71 },
72 time::Duration,
73};
74
75pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
76use fnv::FnvHashMap;
77use futures_channel::oneshot;
78use futures_timer::Delay;
79use futures_util::future::BoxFuture;
80#[cfg(feature = "tracing")]
81use tracing::{info_span, instrument, Instrument};
82#[cfg(feature = "tracing")]
83use tracinglib as tracing;
84
85#[allow(clippy::type_complexity)]
86struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
87 use_cache_values: HashMap<K, T::Value>,
88 tx: oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
89}
90
91struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
92 keys: HashSet<K>,
93 pending: Vec<(HashSet<K>, ResSender<K, T>)>,
94 cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
95 disable_cache: bool,
96}
97
98type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>);
99
100impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
101 fn new<C: CacheFactory>(cache_factory: &C) -> Self {
102 Self {
103 keys: Default::default(),
104 pending: Vec::new(),
105 cache_storage: cache_factory.create::<K, T::Value>(),
106 disable_cache: false,
107 }
108 }
109
110 fn take(&mut self) -> KeysAndSender<K, T> {
111 (
112 std::mem::take(&mut self.keys),
113 std::mem::take(&mut self.pending),
114 )
115 }
116}
117
/// A batch loader: resolves many keys of type `K` in a single call.
///
/// `DataLoader` gathers keys from concurrent `load_one` / `load_many` calls and
/// passes them to [`Loader::load`] together as one slice.
#[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
pub trait Loader<K>: Send + Sync + 'static
where
    K: Send + Sync + Hash + Eq + Clone + 'static,
{
    /// The value produced for each key.
    type Value: Send + Sync + Clone + 'static;

    /// The error produced when a batch fails; it is cloned once per waiting caller.
    type Error: Send + Clone + 'static;

    /// Loads the values for `keys`.
    ///
    /// Keys missing from the returned map are reported to callers as absent
    /// (`None` from `load_one`).
    #[cfg(feature = "boxed-trait")]
    async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;

    /// Loads the values for `keys`.
    ///
    /// Keys missing from the returned map are reported to callers as absent
    /// (`None` from `load_one`). The returned future must be `Send`.
    #[cfg(not(feature = "boxed-trait"))]
    fn load(
        &self,
        keys: &[K],
    ) -> impl Future<Output = Result<HashMap<K, Self::Value>, Self::Error>> + Send;
}
138
139struct DataLoaderInner<T> {
140 requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
141 loader: T,
142}
143
144impl<T> DataLoaderInner<T> {
145 #[cfg_attr(feature = "tracing", instrument(skip_all))]
146 async fn do_load<K>(&self, disable_cache: bool, (keys, senders): KeysAndSender<K, T>)
147 where
148 K: Send + Sync + Hash + Eq + Clone + 'static,
149 T: Loader<K>,
150 {
151 let tid = TypeId::of::<K>();
152 let keys = keys.into_iter().collect::<Vec<_>>();
153
154 match self.loader.load(&keys).await {
155 Ok(values) => {
156 let mut request = self.requests.lock().unwrap();
158 let typed_requests = request
159 .get_mut(&tid)
160 .unwrap()
161 .downcast_mut::<Requests<K, T>>()
162 .unwrap();
163 let disable_cache = typed_requests.disable_cache || disable_cache;
164 if !disable_cache {
165 for (key, value) in &values {
166 typed_requests
167 .cache_storage
168 .insert(Cow::Borrowed(key), Cow::Borrowed(value));
169 }
170 }
171
172 for (keys, sender) in senders {
174 let mut res = HashMap::new();
175 res.extend(sender.use_cache_values);
176 for key in &keys {
177 res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
178 }
179 sender.tx.send(Ok(res)).ok();
180 }
181 }
182 Err(err) => {
183 for (_, sender) in senders {
184 sender.tx.send(Err(err.clone())).ok();
185 }
186 }
187 }
188 }
189}
190
191pub struct DataLoader<T, C = NoCache> {
195 inner: Arc<DataLoaderInner<T>>,
196 cache_factory: C,
197 delay: Duration,
198 max_batch_size: usize,
199 disable_cache: AtomicBool,
200 spawner: Box<dyn Fn(BoxFuture<'static, ()>) + Send + Sync>,
201}
202
203impl<T> DataLoader<T, NoCache> {
204 pub fn new<S, R>(loader: T, spawner: S) -> Self
206 where
207 S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
208 {
209 Self {
210 inner: Arc::new(DataLoaderInner {
211 requests: Mutex::new(Default::default()),
212 loader,
213 }),
214 cache_factory: NoCache,
215 delay: Duration::from_millis(1),
216 max_batch_size: 1000,
217 disable_cache: false.into(),
218 spawner: Box::new(move |fut| {
219 spawner(fut);
220 }),
221 }
222 }
223}
224
225impl<T, C: CacheFactory> DataLoader<T, C> {
226 pub fn with_cache<S, R>(loader: T, spawner: S, cache_factory: C) -> Self
228 where
229 S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
230 {
231 Self {
232 inner: Arc::new(DataLoaderInner {
233 requests: Mutex::new(Default::default()),
234 loader,
235 }),
236 cache_factory,
237 delay: Duration::from_millis(1),
238 max_batch_size: 1000,
239 disable_cache: false.into(),
240 spawner: Box::new(move |fut| {
241 spawner(fut);
242 }),
243 }
244 }
245
246 #[must_use]
248 pub fn delay(self, delay: Duration) -> Self {
249 Self { delay, ..self }
250 }
251
252 #[must_use]
258 pub fn max_batch_size(self, max_batch_size: usize) -> Self {
259 Self {
260 max_batch_size,
261 ..self
262 }
263 }
264
265 #[inline]
267 pub fn loader(&self) -> &T {
268 &self.inner.loader
269 }
270
271 pub fn enable_all_cache(&self, enable: bool) {
273 self.disable_cache.store(!enable, Ordering::SeqCst);
274 }
275
276 pub fn enable_cache<K>(&self, enable: bool)
278 where
279 K: Send + Sync + Hash + Eq + Clone + 'static,
280 T: Loader<K>,
281 {
282 let tid = TypeId::of::<K>();
283 let mut requests = self.inner.requests.lock().unwrap();
284 let typed_requests = requests
285 .get_mut(&tid)
286 .unwrap()
287 .downcast_mut::<Requests<K, T>>()
288 .unwrap();
289 typed_requests.disable_cache = !enable;
290 }
291
292 #[cfg_attr(feature = "tracing", instrument(skip_all))]
294 pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
295 where
296 K: Send + Sync + Hash + Eq + Clone + 'static,
297 T: Loader<K>,
298 {
299 let mut values = self.load_many(std::iter::once(key.clone())).await?;
300 Ok(values.remove(&key))
301 }
302
303 #[cfg_attr(feature = "tracing", instrument(skip_all))]
305 pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
306 where
307 K: Send + Sync + Hash + Eq + Clone + 'static,
308 I: IntoIterator<Item = K>,
309 T: Loader<K>,
310 {
311 enum Action<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
312 ImmediateLoad(KeysAndSender<K, T>),
313 StartFetch,
314 Delay,
315 }
316
317 let tid = TypeId::of::<K>();
318
319 let (action, rx) = {
320 let mut requests = self.inner.requests.lock().unwrap();
321 let typed_requests = requests
322 .entry(tid)
323 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
324 .downcast_mut::<Requests<K, T>>()
325 .unwrap();
326 let prev_count = typed_requests.keys.len();
327 let mut keys_set = HashSet::new();
328 let mut use_cache_values = HashMap::new();
329
330 if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
331 keys_set = keys.into_iter().collect();
332 } else {
333 for key in keys {
334 if let Some(value) = typed_requests.cache_storage.get(&key) {
335 use_cache_values.insert(key.clone(), value.clone());
337 } else {
338 keys_set.insert(key);
339 }
340 }
341 }
342
343 if !use_cache_values.is_empty() && keys_set.is_empty() {
344 return Ok(use_cache_values);
345 } else if use_cache_values.is_empty() && keys_set.is_empty() {
346 return Ok(Default::default());
347 }
348
349 typed_requests.keys.extend(keys_set.clone());
350 let (tx, rx) = oneshot::channel();
351 typed_requests.pending.push((
352 keys_set,
353 ResSender {
354 use_cache_values,
355 tx,
356 },
357 ));
358
359 if typed_requests.keys.len() >= self.max_batch_size {
360 (Action::ImmediateLoad(typed_requests.take()), rx)
361 } else {
362 (
363 if !typed_requests.keys.is_empty() && prev_count == 0 {
364 Action::StartFetch
365 } else {
366 Action::Delay
367 },
368 rx,
369 )
370 }
371 };
372
373 match action {
374 Action::ImmediateLoad(keys) => {
375 let inner = self.inner.clone();
376 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
377 let task = async move { inner.do_load(disable_cache, keys).await };
378 #[cfg(feature = "tracing")]
379 let task = task
380 .instrument(info_span!("immediate_load"))
381 .in_current_span();
382
383 (self.spawner)(Box::pin(task));
384 }
385 Action::StartFetch => {
386 let inner = self.inner.clone();
387 let disable_cache = self.disable_cache.load(Ordering::SeqCst);
388 let delay = self.delay;
389
390 let task = async move {
391 Delay::new(delay).await;
392
393 let keys = {
394 let mut request = inner.requests.lock().unwrap();
395 let typed_requests = request
396 .get_mut(&tid)
397 .unwrap()
398 .downcast_mut::<Requests<K, T>>()
399 .unwrap();
400 typed_requests.take()
401 };
402
403 if !keys.0.is_empty() {
404 inner.do_load(disable_cache, keys).await
405 }
406 };
407 #[cfg(feature = "tracing")]
408 let task = task.instrument(info_span!("start_fetch")).in_current_span();
409 (self.spawner)(Box::pin(task))
410 }
411 Action::Delay => {}
412 }
413
414 rx.await.unwrap()
415 }
416
417 #[cfg_attr(feature = "tracing", instrument(skip_all))]
422 pub async fn feed_many<K, I>(&self, values: I)
423 where
424 K: Send + Sync + Hash + Eq + Clone + 'static,
425 I: IntoIterator<Item = (K, T::Value)>,
426 T: Loader<K>,
427 {
428 let tid = TypeId::of::<K>();
429 let mut requests = self.inner.requests.lock().unwrap();
430 let typed_requests = requests
431 .entry(tid)
432 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
433 .downcast_mut::<Requests<K, T>>()
434 .unwrap();
435 for (key, value) in values {
436 typed_requests
437 .cache_storage
438 .insert(Cow::Owned(key), Cow::Owned(value));
439 }
440 }
441
442 #[cfg_attr(feature = "tracing", instrument(skip_all))]
447 pub async fn feed_one<K>(&self, key: K, value: T::Value)
448 where
449 K: Send + Sync + Hash + Eq + Clone + 'static,
450 T: Loader<K>,
451 {
452 self.feed_many(std::iter::once((key, value))).await;
453 }
454
455 #[cfg_attr(feature = "tracing", instrument(skip_all))]
460 pub fn clear<K>(&self)
461 where
462 K: Send + Sync + Hash + Eq + Clone + 'static,
463 T: Loader<K>,
464 {
465 let tid = TypeId::of::<K>();
466 let mut requests = self.inner.requests.lock().unwrap();
467 let typed_requests = requests
468 .entry(tid)
469 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
470 .downcast_mut::<Requests<K, T>>()
471 .unwrap();
472 typed_requests.cache_storage.clear();
473 }
474
475 pub fn get_cached_values<K>(&self) -> HashMap<K, T::Value>
477 where
478 K: Send + Sync + Hash + Eq + Clone + 'static,
479 T: Loader<K>,
480 {
481 let tid = TypeId::of::<K>();
482 let requests = self.inner.requests.lock().unwrap();
483 match requests.get(&tid) {
484 None => HashMap::new(),
485 Some(requests) => {
486 let typed_requests = requests.downcast_ref::<Requests<K, T>>().unwrap();
487 typed_requests
488 .cache_storage
489 .iter()
490 .map(|(k, v)| (k.clone(), v.clone()))
491 .collect()
492 }
493 }
494 }
495}
496
#[cfg(test)]
mod tests {
    use fnv::FnvBuildHasher;

    use super::*;

    /// Identity loader: every key maps to itself.
    struct MyLoader;

    /// Implements `Loader<$ty>` for `MyLoader` as an identity mapping, asserting
    /// that no batch ever contains more than 10 keys.
    macro_rules! impl_identity_loader {
        ($($ty:ty),* $(,)?) => {
            $(
                #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
                impl Loader<$ty> for MyLoader {
                    type Value = $ty;
                    type Error = ();

                    async fn load(
                        &self,
                        keys: &[$ty],
                    ) -> Result<HashMap<$ty, Self::Value>, Self::Error> {
                        assert!(keys.len() <= 10);
                        Ok(keys.iter().map(|&key| (key, key)).collect())
                    }
                }
            )*
        };
    }

    impl_identity_loader!(i32, i64);

    /// Issues one concurrent `load_one` per key and returns the results in key order.
    async fn load_each<K>(
        loader: &DataLoader<MyLoader>,
        keys: impl IntoIterator<Item = K>,
    ) -> Vec<Option<K>>
    where
        K: Send + Sync + Hash + Eq + Clone + 'static,
        MyLoader: Loader<K, Value = K, Error = ()>,
    {
        futures_util::future::try_join_all(keys.into_iter().map(|key| loader.load_one(key)))
            .await
            .unwrap()
    }

    /// Shared cache scenario: fed values win over the loader, misses go to the
    /// loader, and `clear` drops the fed values again.
    async fn assert_cache_round_trip<C: CacheFactory>(cache_factory: C) {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, cache_factory);
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 10), (2, 20), (3, 30)])
        );
        assert_eq!(
            loader.load_many(vec![1, 5, 6]).await.unwrap(),
            HashMap::from([(1, 10), (5, 5), (6, 6)])
        );
        assert_eq!(
            loader.load_many(vec![8, 9, 10]).await.unwrap(),
            HashMap::from([(8, 8), (9, 9), (10, 10)])
        );

        loader.clear::<i32>();
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 1), (2, 2), (3, 3)])
        );
    }

    #[tokio::test]
    async fn test_dataloader() {
        let loader = DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10);

        let loaded = load_each(&loader, 0..100i32).await;
        assert_eq!(loaded, (0..100).map(Option::Some).collect::<Vec<_>>());

        let loaded = load_each(&loader, 0..100i64).await;
        assert_eq!(loaded, (0..100).map(Option::Some).collect::<Vec<_>>());
    }

    #[tokio::test]
    async fn test_duplicate_keys() {
        let loader = DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10);
        let keys = [1, 3, 5, 1, 7, 8, 3, 7];
        assert_eq!(
            load_each(&loader, keys).await,
            keys.iter().copied().map(Option::Some).collect::<Vec<_>>()
        );
    }

    #[tokio::test]
    async fn test_dataloader_load_empty() {
        let loader = DataLoader::new(MyLoader, tokio::spawn);
        let loaded = loader.load_many::<i32, _>(Vec::new()).await.unwrap();
        assert!(loaded.is_empty());
    }

    #[tokio::test]
    async fn test_dataloader_with_cache() {
        assert_cache_round_trip(HashMapCache::default()).await;
    }

    #[tokio::test]
    async fn test_dataloader_with_cache_hashmap_fnv() {
        assert_cache_round_trip(HashMapCache::<FnvBuildHasher>::new()).await;
    }

    #[tokio::test]
    async fn test_dataloader_disable_all_cache() {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // With caching off globally, the fed values are bypassed.
        loader.enable_all_cache(false);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 1), (2, 2), (3, 3)])
        );

        // Re-enabling exposes the fed values again.
        loader.enable_all_cache(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 10), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_disable_cache() {
        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;

        // With caching off for `i32` only, the fed values are bypassed.
        loader.enable_cache::<i32>(false);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 1), (2, 2), (3, 3)])
        );

        // Re-enabling exposes the fed values again.
        loader.enable_cache::<i32>(true);
        assert_eq!(
            loader.load_many(vec![1, 2, 3]).await.unwrap(),
            HashMap::from([(1, 10), (2, 20), (3, 30)])
        );
    }

    #[tokio::test]
    async fn test_dataloader_dead_lock() {
        // Loader that takes a full second per batch.
        struct MyDelayLoader;

        #[cfg_attr(feature = "boxed-trait", async_trait::async_trait)]
        impl Loader<i32> for MyDelayLoader {
            type Value = i32;
            type Error = ();

            async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
                tokio::time::sleep(Duration::from_secs(1)).await;
                Ok(keys.iter().map(|&key| (key, key)).collect())
            }
        }

        let loader = Arc::new(
            DataLoader::with_cache(MyDelayLoader, tokio::spawn, NoCache)
                .delay(Duration::from_secs(1)),
        );

        // Start a load, then abort its caller while the batch is still pending.
        let aborted = tokio::spawn({
            let loader = Arc::clone(&loader);
            async move {
                loader.load_many(vec![1, 2, 3]).await.unwrap();
            }
        });
        tokio::time::sleep(Duration::from_millis(500)).await;
        aborted.abort();

        // A later load must still complete.
        loader.load_many(vec![4, 5, 6]).await.unwrap();
    }
}