atuin_client/
kv.rs

1use std::collections::BTreeMap;
2
3use atuin_common::record::{DecryptedData, Host, HostId};
4use eyre::{bail, ensure, eyre, Result};
5use serde::Deserialize;
6
7use crate::record::encryption::PASETO_V4;
8use crate::record::store::Store;
9
/// Record format version written by `KvStore::set` and accepted by `KvRecord::deserialize`.
const KV_VERSION: &str = "v0";
/// Tag under which kv records are stored in the record store.
const KV_TAG: &str = "kv";
/// Maximum accepted length of a kv value, in bytes (100 KiB).
const KV_VAL_MAX_LEN: usize = 100 * 1024;
13
/// A single namespaced key/value entry, as encoded into a kv record.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KvRecord {
    /// Namespace that groups related keys.
    pub namespace: String,
    /// Key within `namespace`.
    pub key: String,
    /// Value stored under `key`.
    pub value: String,
}
20
21impl KvRecord {
22    pub fn serialize(&self) -> Result<DecryptedData> {
23        use rmp::encode;
24
25        let mut output = vec![];
26
27        // INFO: ensure this is updated when adding new fields
28        encode::write_array_len(&mut output, 3)?;
29
30        encode::write_str(&mut output, &self.namespace)?;
31        encode::write_str(&mut output, &self.key)?;
32        encode::write_str(&mut output, &self.value)?;
33
34        Ok(DecryptedData(output))
35    }
36
37    pub fn deserialize(data: &DecryptedData, version: &str) -> Result<Self> {
38        use rmp::decode;
39
40        fn error_report<E: std::fmt::Debug>(err: E) -> eyre::Report {
41            eyre!("{err:?}")
42        }
43
44        match version {
45            KV_VERSION => {
46                let mut bytes = decode::Bytes::new(&data.0);
47
48                let nfields = decode::read_array_len(&mut bytes).map_err(error_report)?;
49                ensure!(nfields == 3, "too many entries in v0 kv record");
50
51                let bytes = bytes.remaining_slice();
52
53                let (namespace, bytes) =
54                    decode::read_str_from_slice(bytes).map_err(error_report)?;
55                let (key, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
56                let (value, bytes) = decode::read_str_from_slice(bytes).map_err(error_report)?;
57
58                if !bytes.is_empty() {
59                    bail!("trailing bytes in encoded kvrecord. malformed")
60                }
61
62                Ok(KvRecord {
63                    namespace: namespace.to_owned(),
64                    key: key.to_owned(),
65                    value: value.to_owned(),
66                })
67            }
68            _ => {
69                bail!("unknown version {version:?}")
70            }
71        }
72    }
73}
74
75#[derive(Debug, Clone, Deserialize)]
76pub struct KvStore;
77
78impl Default for KvStore {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl KvStore {
85    // will want to init the actual kv store when that is done
86    pub fn new() -> KvStore {
87        KvStore {}
88    }
89
90    pub async fn set(
91        &self,
92        store: &(impl Store + Send + Sync),
93        encryption_key: &[u8; 32],
94        host_id: HostId,
95        namespace: &str,
96        key: &str,
97        value: &str,
98    ) -> Result<()> {
99        if value.len() > KV_VAL_MAX_LEN {
100            return Err(eyre!(
101                "kv value too large: max len {} bytes",
102                KV_VAL_MAX_LEN
103            ));
104        }
105
106        let record = KvRecord {
107            namespace: namespace.to_string(),
108            key: key.to_string(),
109            value: value.to_string(),
110        };
111
112        let bytes = record.serialize()?;
113
114        let idx = store
115            .last(host_id, KV_TAG)
116            .await?
117            .map_or(0, |entry| entry.idx + 1);
118
119        let record = atuin_common::record::Record::builder()
120            .host(Host::new(host_id))
121            .version(KV_VERSION.to_string())
122            .tag(KV_TAG.to_string())
123            .idx(idx)
124            .data(bytes)
125            .build();
126
127        store
128            .push(&record.encrypt::<PASETO_V4>(encryption_key))
129            .await?;
130
131        Ok(())
132    }
133
134    // TODO: setup an actual kv store, rebuild func, and do not pass the main store in here as
135    // well.
136    pub async fn get(
137        &self,
138        store: &impl Store,
139        encryption_key: &[u8; 32],
140        namespace: &str,
141        key: &str,
142    ) -> Result<Option<KvRecord>> {
143        // TODO: don't rebuild every time...
144        let map = self.build_kv(store, encryption_key).await?;
145
146        let res = map.get(namespace);
147
148        if let Some(ns) = res {
149            let value = ns.get(key);
150
151            Ok(value.cloned())
152        } else {
153            Ok(None)
154        }
155    }
156
157    // Build a kv map out of the linked list kv store
158    // Map is Namespace -> Key -> Value
159    // TODO(ellie): "cache" this into a real kv structure, which we can
160    // use as a write-through cache to avoid constant rebuilds.
161    pub async fn build_kv(
162        &self,
163        store: &impl Store,
164        encryption_key: &[u8; 32],
165    ) -> Result<BTreeMap<String, BTreeMap<String, KvRecord>>> {
166        let mut map = BTreeMap::new();
167
168        // TODO: maybe don't load the entire tag into memory to build the kv
169        // we can be smart about it and only load values since the last build
170        // or, iterate/paginate
171        let tagged = store.all_tagged(KV_TAG).await?;
172
173        // iterate through all tags and play each KV record at a time
174        // this is "last write wins"
175        // probably good enough for now, but revisit in future
176        for record in tagged {
177            let decrypted = match record.version.as_str() {
178                KV_VERSION => record.decrypt::<PASETO_V4>(encryption_key)?,
179                version => bail!("unknown version {version:?}"),
180            };
181
182            let kv = KvRecord::deserialize(&decrypted.data, KV_VERSION)?;
183
184            let ns = map
185                .entry(kv.namespace.clone())
186                .or_insert_with(BTreeMap::new);
187
188            ns.insert(kv.key.clone(), kv);
189        }
190
191        Ok(map)
192    }
193}
194
#[cfg(test)]
mod tests {
    use crypto_secretbox::{KeyInit, XSalsa20Poly1305};
    use rand::rngs::OsRng;

    use crate::record::sqlite_store::SqliteStore;
    use crate::settings::test_local_timeout;

    use super::{KvRecord, KvStore, KV_VERSION};

    /// Shorthand for building an expected record in assertions.
    fn record(namespace: &str, key: &str, value: &str) -> KvRecord {
        KvRecord {
            namespace: namespace.to_owned(),
            key: key.to_owned(),
            value: value.to_owned(),
        }
    }

    /// Round-trip a record through (de)serialization, pinning the exact bytes
    /// so accidental wire-format changes are caught.
    #[test]
    fn encode_decode() {
        let original = record("foo", "bar", "baz");
        // fixarray of length 3, followed by three 3-byte fixstr entries
        let expected_bytes = [
            0x93, 0xa3, b'f', b'o', b'o', 0xa3, b'b', b'a', b'r', 0xa3, b'b', b'a', b'z',
        ];

        let encoded = original.serialize().unwrap();
        let decoded = KvRecord::deserialize(&encoded, KV_VERSION).unwrap();

        assert_eq!(encoded.0, &expected_bytes);
        assert_eq!(decoded, original);
    }

    /// Values written with `set` are all visible in the rebuilt map.
    #[tokio::test]
    async fn build_kv() {
        let store = SqliteStore::new(":memory:", test_local_timeout())
            .await
            .unwrap();
        let kv = KvStore::new();
        let key: [u8; 32] = XSalsa20Poly1305::generate_key(&mut OsRng).into();
        let host_id = atuin_common::record::HostId(atuin_common::utils::uuid_v7());

        let entries = [("foo", "bar"), ("1", "2")];

        for (k, v) in entries {
            kv.set(&store, &key, host_id, "test-kv", k, v)
                .await
                .unwrap();
        }

        let map = kv.build_kv(&store, &key).await.unwrap();
        let namespace = map.get("test-kv").expect("map namespace not set");

        for (k, v) in entries {
            assert_eq!(
                *namespace.get(k).expect("map key not set"),
                record("test-kv", k, v)
            );
        }
    }
}