aws_smithy_runtime/static_partition_map.rs
1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use once_cell::sync::OnceCell;
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::sync::{Mutex, MutexGuard};
10
11/// A data structure for persisting and sharing state between multiple clients.
12///
13/// Some state should be shared between multiple clients. For example, when creating multiple clients
14/// for the same service, it's desirable to share a client rate limiter. This way, when one client
15/// receives a throttling response, the other clients will be aware of it as well.
16///
17/// Whether clients share state is dependent on their partition key `K`. Going back to the client
18/// rate limiter example, `K` would be a struct containing the name of the service as well as the
19/// client's configured region, since receiving throttling responses in `us-east-1` shouldn't
20/// throttle requests to the same service made in other regions.
21///
22/// Values stored in a `StaticPartitionMap` will be cloned whenever they are requested. Values must
23/// be initialized before they can be retrieved, and the `StaticPartitionMap::get_or_init` method is
24/// how you can ensure this.
25///
26/// # Example
27///
28/// ```
/// use std::sync::{Arc, Mutex};
30/// use aws_smithy_runtime::static_partition_map::StaticPartitionMap;
31///
32/// // The shared state must be `Clone` and will be internally mutable. Deriving `Default` isn't
33/// // necessary, but allows us to use the `StaticPartitionMap::get_or_init_default` method.
34/// #[derive(Clone, Default)]
35/// pub struct SomeSharedState {
36/// inner: Arc<Mutex<Inner>>
37/// }
38///
39/// #[derive(Default)]
40/// struct Inner {
41/// // Some shared state...
42/// }
43///
44/// // `Clone`, `Hash`, and `Eq` are all required trait impls for partition keys
45/// #[derive(Clone, Hash, PartialEq, Eq)]
46/// pub struct SharedStatePartition {
47/// region: String,
48/// service_name: String,
49/// }
50///
51/// impl SharedStatePartition {
52/// pub fn new(region: impl Into<String>, service_name: impl Into<String>) -> Self {
53/// Self { region: region.into(), service_name: service_name.into() }
54/// }
55/// }
56///
57/// static SOME_SHARED_STATE: StaticPartitionMap<SharedStatePartition, SomeSharedState> = StaticPartitionMap::new();
58///
59/// struct Client {
60/// shared_state: SomeSharedState,
61/// }
62///
63/// impl Client {
64/// pub fn new() -> Self {
65/// let key = SharedStatePartition::new("us-east-1", "example_service_20230628");
66/// Self {
67/// // If the stored value implements `Default`, you can call the
68/// // `StaticPartitionMap::get_or_init_default` convenience method.
69/// shared_state: SOME_SHARED_STATE.get_or_init_default(key),
70/// }
71/// }
72/// }
73/// ```
#[derive(Debug)]
pub struct StaticPartitionMap<K, V> {
    // Created lazily on first access. `HashMap::new` is not a `const fn`, so deferring the
    // allocation behind a `OnceCell` is what lets `new` be used to initialize a `static`.
    inner: OnceCell<Mutex<HashMap<K, V>>>,
}

/// Creates an empty map, exactly like `StaticPartitionMap::new`.
///
/// Implemented by hand rather than derived: `#[derive(Default)]` would require
/// `K: Default` and `V: Default`, even though an empty map never needs to construct
/// a key or a value.
impl<K, V> Default for StaticPartitionMap<K, V> {
    fn default() -> Self {
        Self {
            inner: OnceCell::new(),
        }
    }
}
78
79impl<K, V> StaticPartitionMap<K, V> {
80 /// Creates a new `StaticPartitionMap`.
81 pub const fn new() -> Self {
82 Self {
83 inner: OnceCell::new(),
84 }
85 }
86}
87
88impl<K, V> StaticPartitionMap<K, V>
89where
90 K: Eq + Hash,
91{
92 fn get_or_init_inner(&self) -> MutexGuard<'_, HashMap<K, V>> {
93 self.inner
94 // At the very least, we'll always be storing the default state.
95 .get_or_init(|| Mutex::new(HashMap::with_capacity(1)))
96 .lock()
97 .unwrap()
98 }
99}
100
101impl<K, V> StaticPartitionMap<K, V>
102where
103 K: Eq + Hash,
104 V: Clone,
105{
106 /// Gets the value for the given partition key.
107 #[must_use]
108 pub fn get(&self, partition_key: K) -> Option<V> {
109 self.get_or_init_inner().get(&partition_key).cloned()
110 }
111
112 /// Gets the value for the given partition key, initializing it with `init` if it doesn't exist.
113 #[must_use]
114 pub fn get_or_init<F>(&self, partition_key: K, init: F) -> V
115 where
116 F: FnOnce() -> V,
117 {
118 let mut inner = self.get_or_init_inner();
119 let v = inner.entry(partition_key).or_insert_with(init);
120 v.clone()
121 }
122}
123
124impl<K, V> StaticPartitionMap<K, V>
125where
126 K: Eq + Hash,
127 V: Clone + Default,
128{
129 /// Gets the value for the given partition key, initializing it if it doesn't exist.
130 #[must_use]
131 pub fn get_or_init_default(&self, partition_key: K) -> V {
132 self.get_or_init(partition_key, V::default)
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::StaticPartitionMap;

    #[test]
    fn test_keyed_partition_returns_same_value_for_same_key() {
        let kp = StaticPartitionMap::new();
        let _ = kp.get_or_init("A", || "A".to_owned());
        let actual = kp.get_or_init("A", || "B".to_owned());
        let expected = "A".to_owned();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_keyed_partition_returns_different_value_for_different_key() {
        let kp = StaticPartitionMap::new();
        let _ = kp.get_or_init("A", || "A".to_owned());
        let actual = kp.get_or_init("B", || "B".to_owned());

        let expected = "B".to_owned();
        assert_eq!(expected, actual);

        let actual = kp.get("A").unwrap();
        let expected = "A".to_owned();
        assert_eq!(expected, actual);
    }

    /// `get` never initializes anything: absent keys stay absent.
    #[test]
    fn test_get_returns_none_for_uninitialized_key() {
        let kp: StaticPartitionMap<&str, String> = StaticPartitionMap::new();
        assert_eq!(None, kp.get("A"));

        let _ = kp.get_or_init("B", || "B".to_owned());
        assert_eq!(None, kp.get("A"));
    }

    /// Once a key holds a value, later initializers must not run at all.
    #[test]
    fn test_init_is_not_called_for_existing_key() {
        let kp = StaticPartitionMap::new();
        let _ = kp.get_or_init("A", || 1_u32);
        let actual = kp.get_or_init("A", || -> u32 {
            panic!("init must not run for an initialized key")
        });
        assert_eq!(1, actual);
    }

    #[test]
    fn test_get_or_init_default_uses_default_only_when_absent() {
        let kp: StaticPartitionMap<&str, u32> = StaticPartitionMap::new();
        assert_eq!(0, kp.get_or_init_default("A"));

        let _ = kp.get_or_init("B", || 5);
        assert_eq!(5, kp.get_or_init_default("B"));
    }

    /// The documented use case: a `static` initialized with the `const fn new`.
    #[test]
    fn test_can_be_used_as_static() {
        static MAP: StaticPartitionMap<&'static str, u32> = StaticPartitionMap::new();
        let _ = MAP.get_or_init("A", || 7);
        assert_eq!(Some(7), MAP.get("A"));
    }
}