aws_smithy_runtime/
static_partition_map.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use once_cell::sync::OnceCell;
7use std::collections::HashMap;
8use std::hash::Hash;
9use std::sync::{Mutex, MutexGuard};
10
/// A data structure for persisting and sharing state between multiple clients.
///
/// Some state should be shared between multiple clients. For example, when creating multiple clients
/// for the same service, it's desirable to share a client rate limiter. This way, when one client
/// receives a throttling response, the other clients will be aware of it as well.
///
/// Whether clients share state is dependent on their partition key `K`. Going back to the client
/// rate limiter example, `K` would be a struct containing the name of the service as well as the
/// client's configured region, since receiving throttling responses in `us-east-1` shouldn't
/// throttle requests to the same service made in other regions.
///
/// Values stored in a `StaticPartitionMap` will be cloned whenever they are requested. Values must
/// be initialized before they can be retrieved, and the `StaticPartitionMap::get_or_init` method is
/// how you can ensure this.
///
/// # Example
///
/// ```
/// use std::sync::{Arc, Mutex};
/// use aws_smithy_runtime::static_partition_map::StaticPartitionMap;
///
/// // The shared state must be `Clone` and will be internally mutable. Deriving `Default` isn't
/// // necessary, but allows us to use the `StaticPartitionMap::get_or_init_default` method.
/// #[derive(Clone, Default)]
/// pub struct SomeSharedState {
///     inner: Arc<Mutex<Inner>>
/// }
///
/// #[derive(Default)]
/// struct Inner {
///     // Some shared state...
/// }
///
/// // `Hash` and `Eq` are the required trait impls for partition keys. `Clone` isn't required,
/// // but it is convenient because keys are passed to the map by value.
/// #[derive(Clone, Hash, PartialEq, Eq)]
/// pub struct SharedStatePartition {
///     region: String,
///     service_name: String,
/// }
///
/// impl SharedStatePartition {
///     pub fn new(region: impl Into<String>, service_name: impl Into<String>) -> Self {
///         Self { region: region.into(), service_name: service_name.into() }
///     }
/// }
///
/// static SOME_SHARED_STATE: StaticPartitionMap<SharedStatePartition, SomeSharedState> = StaticPartitionMap::new();
///
/// struct Client {
///     shared_state: SomeSharedState,
/// }
///
/// impl Client {
///     pub fn new() -> Self {
///         let key = SharedStatePartition::new("us-east-1", "example_service_20230628");
///         Self {
///             // If the stored value implements `Default`, you can call the
///             // `StaticPartitionMap::get_or_init_default` convenience method.
///             shared_state: SOME_SHARED_STATE.get_or_init_default(key),
///         }
///     }
/// }
/// ```
#[derive(Debug)]
pub struct StaticPartitionMap<K, V> {
    // Starts out empty and is filled in on first access (see `get_or_init_inner`), which is what
    // allows `new` to be a `const fn` usable in `static` initializers.
    inner: OnceCell<Mutex<HashMap<K, V>>>,
}

// Implemented by hand rather than derived: `#[derive(Default)]` would require `K: Default` and
// `V: Default`, even though constructing an empty map needs neither.
impl<K, V> Default for StaticPartitionMap<K, V> {
    /// Creates an empty `StaticPartitionMap`, exactly like [`StaticPartitionMap::new`].
    fn default() -> Self {
        Self {
            inner: OnceCell::new(),
        }
    }
}
78
79impl<K, V> StaticPartitionMap<K, V> {
80    /// Creates a new `StaticPartitionMap`.
81    pub const fn new() -> Self {
82        Self {
83            inner: OnceCell::new(),
84        }
85    }
86}
87
88impl<K, V> StaticPartitionMap<K, V>
89where
90    K: Eq + Hash,
91{
92    fn get_or_init_inner(&self) -> MutexGuard<'_, HashMap<K, V>> {
93        self.inner
94            // At the very least, we'll always be storing the default state.
95            .get_or_init(|| Mutex::new(HashMap::with_capacity(1)))
96            .lock()
97            .unwrap()
98    }
99}
100
101impl<K, V> StaticPartitionMap<K, V>
102where
103    K: Eq + Hash,
104    V: Clone,
105{
106    /// Gets the value for the given partition key.
107    #[must_use]
108    pub fn get(&self, partition_key: K) -> Option<V> {
109        self.get_or_init_inner().get(&partition_key).cloned()
110    }
111
112    /// Gets the value for the given partition key, initializing it with `init` if it doesn't exist.
113    #[must_use]
114    pub fn get_or_init<F>(&self, partition_key: K, init: F) -> V
115    where
116        F: FnOnce() -> V,
117    {
118        let mut inner = self.get_or_init_inner();
119        let v = inner.entry(partition_key).or_insert_with(init);
120        v.clone()
121    }
122}
123
124impl<K, V> StaticPartitionMap<K, V>
125where
126    K: Eq + Hash,
127    V: Clone + Default,
128{
129    /// Gets the value for the given partition key, initializing it if it doesn't exist.
130    #[must_use]
131    pub fn get_or_init_default(&self, partition_key: K) -> V {
132        self.get_or_init(partition_key, V::default)
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::StaticPartitionMap;

    /// Once a key holds a value, a later `get_or_init` for that key returns the stored value and
    /// does not use the new initializer's result.
    #[test]
    fn test_keyed_partition_returns_same_value_for_same_key() {
        let partitions = StaticPartitionMap::new();
        let _first = partitions.get_or_init("A", || String::from("A"));

        let second = partitions.get_or_init("A", || String::from("B"));
        assert_eq!(String::from("A"), second);
    }

    /// Distinct keys hold independent values, and initializing one key leaves the other intact.
    #[test]
    fn test_keyed_partition_returns_different_value_for_different_key() {
        let partitions = StaticPartitionMap::new();
        let _first = partitions.get_or_init("A", || String::from("A"));

        let for_b = partitions.get_or_init("B", || String::from("B"));
        assert_eq!(String::from("B"), for_b);

        let for_a = partitions.get("A").unwrap();
        assert_eq!(String::from("A"), for_a);
    }
}